from typing import Dict
from unittest.mock import MagicMock

import pytest
import pytest_asyncio
from fastapi.middleware.cors import CORSMiddleware
from httpx import ASGITransport, AsyncClient

from dbgpt.component import SystemApp
from dbgpt.util import AppConfig
from dbgpt.util.fastapi import create_app
from dbgpt_serve.core import BaseServeConfig


def create_system_app(param: Dict) -> SystemApp:
    """Build a ``SystemApp`` around a freshly created FastAPI test app.

    Args:
        param: Fixture parameters. Only the ``"app_config"`` key is read; it
            may be an ``AppConfig`` instance or a plain dict (wrapped as
            ``AppConfig(configs=...)``). When absent, an empty dict is used.

    Returns:
        A ``SystemApp`` wrapping the new app and the resolved config.

    Raises:
        RuntimeError: If ``"app_config"`` is neither a dict nor an ``AppConfig``.
    """
    resolved_config = param.get("app_config", {})
    if not isinstance(resolved_config, (dict, AppConfig)):
        raise RuntimeError("app_config must be AppConfig or dict")
    if isinstance(resolved_config, dict):
        resolved_config = AppConfig(configs=resolved_config)

    # Permissive CORS so tests are not blocked by cross-origin checks.
    cors_options = dict(
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
        allow_headers=["*"],
    )
    fastapi_app = create_app()
    fastapi_app.add_middleware(CORSMiddleware, **cors_options)

    return SystemApp(fastapi_app, resolved_config)


@pytest_asyncio.fixture
async def asystem_app(request):
    """Async fixture yielding a new ``SystemApp`` from :func:`create_system_app`.

    When used with indirect parametrization, ``request.param`` is passed
    through unchanged; otherwise an empty dict is used.
    """
    return create_system_app(getattr(request, "param", {}))


@pytest.fixture
def system_app(request):
    """Sync fixture returning a new ``SystemApp`` from :func:`create_system_app`.

    When used with indirect parametrization, ``request.param`` is passed
    through unchanged; otherwise an empty dict is used.
    """
    return create_system_app(getattr(request, "param", {}))


@pytest.fixture
def config():
    """Return a ``MagicMock`` specced on ``BaseServeConfig`` with fixed test values.

    Sets ``api_keys``, ``load_dbgpts_interval``, ``default_user`` and
    ``default_sys_code`` to deterministic values for tests.
    """
    fake_config = MagicMock(spec=BaseServeConfig)
    preset_values = {
        "api_keys": "mock_api_key_123",
        "load_dbgpts_interval": 0,
        "default_user": "dbgpt",
        "default_sys_code": "dbgpt",
    }
    for attr_name, attr_value in preset_values.items():
        setattr(fake_config, attr_name, attr_value)
    return fake_config


@pytest_asyncio.fixture
async def client(request, asystem_app: SystemApp, config: BaseServeConfig):
    """Yield an ``httpx.AsyncClient`` that talks to ``asystem_app``'s app in-process.

    Optional indirect parameters (keys of the ``request.param`` dict):
        headers: Default request headers. Copied, so the caller's dict is
            never modified.
        base_url: Base URL for requests; defaults to ``"http://test"``.
        client_api_key: If truthy, sent as ``Authorization: Bearer <key>``.
        routers: Routers registered via ``include_router`` on the app.
        app_caller: Callable invoked as ``app_caller(app, system_app, config)``
            after routers are included and before the client is yielded.
        api_keys: Removed from the param dict; not otherwise read here.
    """
    param = getattr(request, "param", {})
    # Copy: adding Authorization below must not mutate a parametrize dict
    # that may be shared with other tests.
    headers = dict(param.get("headers", {}))
    base_url = param.get("base_url", "http://test")
    client_api_key = param.get("client_api_key")
    routers = param.get("routers", [])
    app_caller = param.get("app_caller")
    # NOTE(review): "api_keys" is dropped from the (shared) param dict but never
    # used by this fixture; kept for compatibility — confirm whether still needed.
    param.pop("api_keys", None)
    if client_api_key:
        headers["Authorization"] = f"Bearer {client_api_key}"

    test_app = asystem_app.app

    async with AsyncClient(
        transport=ASGITransport(test_app), base_url=base_url, headers=headers
    ) as async_client:
        for router in routers:
            test_app.include_router(router)
        if app_caller:
            app_caller(test_app, asystem_app, config)
        yield async_client
